import argparse
import pandas as pd
from transformers import pipeline
import torch
from datasets import Dataset
import warnings

def get_labels(batch, classes, template, classifier):
    """Score every text in a batch against every candidate class.

    Intended as a ``Dataset.map(..., batched=True)`` callback: ``batch`` is a
    mapping of columns, and the returned dict becomes new columns whose lists
    are aligned with ``batch["text"]``.

    Args:
        batch: Mapping with a ``"text"`` key holding the list of input texts.
        classes: Candidate labels passed to the classifier.
        template: Hypothesis template (with a ``{}`` placeholder) forwarded as
            ``hypothesis_template``; ``template.format(label)`` is also used as
            the output column name for each label.
        classifier: Callable with the zero-shot-classification pipeline
            signature ``classifier(texts, labels, hypothesis_template=...,
            multi_label=...)``, returning one ``{"labels": [...], "scores":
            [...]}`` dict per text.

    Returns:
        Dict mapping ``template.format(label)`` to a list of scores, one per
        input text, in input order.
    """
    texts = batch["text"]

    # Compute each label's column name once. Scores come back ordered by score,
    # not by ``classes``, so they are routed to columns by label name.
    column_for = {label: template.format(label) for label in classes}
    results = {column: [] for column in column_for.values()}

    if len(texts) == 0:
        return results

    # multi_label=True scores each label independently instead of
    # normalising the scores across labels.
    outputs = classifier(texts, classes, hypothesis_template=template, multi_label=True)

    # Older transformers versions may return a bare dict rather than a
    # one-element list for a single sequence; normalise so the loop below
    # always iterates over per-text results.
    if isinstance(outputs, dict):
        outputs = [outputs]

    for output in outputs:
        for label, score in zip(output["labels"], output["scores"]):
            results[column_for[label]].append(score)

    return results

def main(inp_file, output_loc, topics, template, text_col='clean_text', *, require_gpu=True):
    """Run zero-shot classification over a CSV text column and save the scores.

    Reads ``inp_file`` and scores each text in ``text_col`` against every label
    in ``topics`` with ``facebook/bart-large-mnli`` (multi-label). It then
    writes the input rows plus one score column per label, named
    ``template.format(label)``, to ``output_loc``.

    Args:
        inp_file: Path of the input CSV.
        output_loc: Path of the output CSV (written without the index).
        topics: Candidate labels.
        template: Hypothesis template; must contain a ``{}`` placeholder.
        text_col: Name of the column that holds the texts.
        require_gpu: When True (the default, matching earlier behaviour),
            refuse to run without CUDA; when False, fall back to the CPU.

    Raises:
        ValueError: If ``template`` has no placeholder for the label.
        KeyError: If ``text_col`` is not a column of the input CSV.
        RuntimeError: If ``require_gpu`` is True and CUDA is unavailable.
    """
    # Validate cheap inputs before loading the (large) model.
    if template.format("__label__") == template:
        raise ValueError(
            f"template {template!r} has no '{{}}' placeholder for the label"
        )

    df = pd.read_csv(inp_file)
    if text_col not in df.columns:
        raise KeyError(
            f"column {text_col!r} not found in {inp_file}; "
            f"available columns: {list(df.columns)}"
        )

    # Pipeline device index: 0 = first CUDA GPU, -1 = CPU.
    if torch.cuda.is_available():
        device = 0
    elif require_gpu:
        raise RuntimeError(
            "CUDA is not available; call main(..., require_gpu=False) to run on CPU"
        )
    else:
        device = -1

    classifier = pipeline("zero-shot-classification",
                          model="facebook/bart-large-mnli",
                          device=device)

    # Label sets and templates from earlier runs, kept for reference only.
    # They are not used; the labels and template come from `topics`/`template`.
    #     topics = ['Israel',
    #               'Ukraine',
    #               'taxation',
    #               'government spending',
    #               'government corruption',
    #               'crime',
    #               'socialism',
    #               'immigration',
    #               'racism',
    #               'abortion',
    #               'guns',
    #               'religion',
    #               'the Middle East',
    #               'Russia',
    #               'China',
    #              ]
    #     topics_template = 'The text mentions {}.'
    #
    #     kind_of_news = ['US domestic news',
    #                     'international news']
    #     kind_of_news_template = 'This is {}.'
    #
    #     complains_against = ['The Democratic Party',
    #                          'The Republican Party',
    #                          'The US Federal government',
    #                          'Joe Biden',
    #                          'Donald Trump']
    #     complains_against_template = '{} is bad.'
    #
    #     tone_of_news = ['conspiracy theory',
    #                     'fake news article',
    #                     'legitimate news article']
    #     tone_of_news_template = 'This is a {}.'
    #
    #     heroes = ['The Republican party', 'Donald Trump', 'DeSantis', 'Russia', 'RFK jr']
    #     heroes_template = '{} is good.'
    #
    #     villains = ['The Democratic party', 'Joe Biden', 'The war in Ukraine', 'Big business', 'The pharmaceutical industry']
    #     villains_template = '{} is bad.'

    # The join below aligns on index, so give df the same 0..n-1 positions
    # as the Dataset rows.
    df = df.reset_index(drop=True)
    dataset = Dataset.from_dict({"text": df[text_col].tolist()})

    # Scope the warning filter to the classification run instead of mutating
    # the process-wide filter list.
    with warnings.catch_warnings():
        # Suppress specific UserWarning from DataLoader
        warnings.filterwarnings("ignore", message="Length of IterableDataset")
        topics_results = dataset.map(
            lambda batch: get_labels(batch, topics, template, classifier),
            batched=True,
        )

    scores = topics_results.to_pandas().drop(columns=["text"])
    df = df.join(scores)
    df.to_csv(output_loc, index=False)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="Get scores from NLI model as ZSC.")
    parser.add_argument('inp_file', type=str, help='Input file loc')
    # Positional arguments are always required, so a `default=` would never apply.
    parser.add_argument('output_loc', type=str, help='Output file loc')
    # One or more topic strings (all positionals between output_loc and template).
    parser.add_argument('topics', type=str, nargs='+', help='List of topics to be used in ZSC')
    # The final positional is the hypothesis template.
    parser.add_argument('template', type=str, help='Template to be used in ZSC. Use {} to indicate where the label should be inserted.')
    parser.add_argument('--text_col', type=str, default='clean_text', help='Column name of the text to be processed in the input file')
    parser.add_argument('--check_inps', action='store_true', help='Check whether the topics and template are correctly formatted')

    args = parser.parse_args()

    # Optionally show the parsed topics/template and ask the user to confirm
    # them before the (slow) model run starts.
    if args.check_inps:
        print('Topics:', args.topics)
        print('Template:', args.template)
        check_inps = input('Are the topics and template correctly formatted? (y/n): ')
        if check_inps.strip().lower() not in ('y', 'yes'):
            print('Please provide the topics as a list of strings and the template as a string.')
            # `exit` is a site-module convenience; SystemExit is always available.
            raise SystemExit(1)

    # Forward --text_col so the chosen column is actually used.
    main(args.inp_file, args.output_loc, args.topics, args.template, args.text_col)
